Dr. Thomas Camminady

Nutriscore analysis

analysis
nutriscore
Author

Thomas Camminady

Published

June 29, 2023

Code
# %load_ext autoreload
# %autoreload 2

import os
import polars as pl
import altair as alt
from camminapy.plot import altair_theme, altair_setup
from rich import print
from blog import logger
from blog.io import (
    FilterGermany,
    ValidCode,
    ConvertNutrientsToFloats,
    HasNutriScore,
    ConvertDatetimeStrings,
    KeepOnlyEnglishVersion,
    KeepOnlyTagsVersion,
    ConvertNutriScoresToCapitalLetters,
)

alt.data_transformers.disable_max_rows()
altair_theme()
# altair_setup()
logger.setLevel("ERROR")

load_from_disk = True
final_df_path = (
    "/Users/thomascamminady/Repos/blog/contents/posts/nutri-score/df.parquet"
)
if load_from_disk and os.path.exists(final_df_path):
    df = pl.read_parquet(final_df_path)
else:
    path_csv = "en.openfoodfacts.org.products.csv"
    path_parquet = path_csv.replace(".csv", ".parquet")
    if os.path.exists(path_parquet):
        logger.info("Reading from .parquet file")
        _df = pl.read_parquet(path_parquet)
    else:
        logger.info("Reading from .csv file.")
        _df = pl.read_csv(path_csv, separator="\t", ignore_errors=True)
        _df.write_parquet(path_parquet)

    df = (
        _df.pipe(FilterGermany)  # We focus on products sold in Germany.
        .pipe(HasNutriScore)  # We want products that have a non-null Nutriscore.
        .pipe(ConvertNutrientsToFloats)  # Change dtype of columns with nutrients.
        .pipe(ConvertDatetimeStrings)  # Change dtype to be Datetime.
        .pipe(KeepOnlyEnglishVersion)  # Some duplicate columns, remove unwanted.
        .pipe(KeepOnlyTagsVersion)  # Some duplicate columns, remove unwanted.
        .pipe(ValidCode)  # Product code needs to be non-null and unique.
        .pipe(ConvertNutriScoresToCapitalLetters)
    )
    df.write_parquet(final_df_path)

n_non_nutrients = len([c for c in df.columns if not c.endswith("_100g")])
n_nutrients = len([c for c in df.columns if c.endswith("_100g")])
logger.info(f"""Number of columns excluding nutrients: {n_non_nutrients}""")
logger.info(f"""Number of nutrients columns: {n_nutrients}""")

The Data

The data is taken from Open Food Facts where it is available for download under the Open Database License.

I downloaded en.openfoodfacts.org.products.csv, which is an 8.2 GB file containing 2.9 million products with 203 attributes per product, such as name, origin, sugar per 100g, etc.

To be able to work with this data, we apply a couple of pre-processing steps and selections, which are laid out here with justifications where necessary:

  • Products must be sold in Germany. (I am German and I wanted to at least know some of the products. It also made my life easier because it reduces the amount of data by roughly a factor of 10.)
  • Products must have a valid Nutriscore. (The whole point of this analysis is to have a look at the Nutriscore, so items with a null Nutriscore are discarded. This reduces the data by another 3x.)
  • Remove redundant information. (Some columns are redundant as they contain the same content, just in a different format. This drops about 25 columns.)
  • Remove products with non-unique IDs. (This drops another couple of rows.)
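The selection steps above can be sketched on plain Python records. This is not the implementation in blog.io (which is not shown here); it is a hypothetical stand-in that assumes each product is a dict with the Open Food Facts columns code, countries_en, and nutriscore_grade, leaves out the column-dropping step, and interprets "non-unique IDs" as dropping every row whose code occurs more than once.

```python
from collections import Counter


def select_products(rows: list[dict]) -> list[dict]:
    """Hypothetical stand-in for the row-level selection steps described above."""
    # Products must be sold in Germany (countries_en is a comma-separated list).
    rows = [
        r for r in rows
        if "Germany" in [c.strip() for c in (r.get("countries_en") or "").split(",")]
    ]
    # Products must have a non-null Nutriscore grade.
    rows = [r for r in rows if r.get("nutriscore_grade")]
    # Product codes must be non-null and occur exactly once.
    counts = Counter(r.get("code") for r in rows)
    return [r for r in rows if r.get("code") and counts[r["code"]] == 1]


products = [
    {"code": "1", "countries_en": "Germany,France", "nutriscore_grade": "a"},
    {"code": "2", "countries_en": "France", "nutriscore_grade": "b"},
    {"code": "3", "countries_en": "Germany", "nutriscore_grade": None},
    {"code": "4", "countries_en": "Germany", "nutriscore_grade": "c"},
    {"code": "4", "countries_en": "Germany", "nutriscore_grade": "d"},
    {"code": None, "countries_en": "Germany", "nutriscore_grade": "e"},
]
print([r["code"] for r in select_products(products)])  # ['1']
```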

Executing all these steps, we are left with around 73k products and 175 columns corresponding to different data fields. Out of those 175 columns, 118 contain information about the nutrients (per 100g).

Here’s how the products are split across the different Nutriscores.

Code
df.groupby("nutriscore_grade", maintain_order=True).count().sort("nutriscore_grade")
shape: (5, 2)
nutriscore_grade count
str u32
"A" 11402
"B" 10033
"C" 16877
"D" 21253
"E" 13742

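As a quick sanity check on these counts, the five grades add up to the roughly 73k products mentioned above, and “D” is the most common grade at about 29%. The numbers in the sketch are copied from the table above.

```python
# Product counts per grade, copied from the table above.
counts = {"A": 11402, "B": 10033, "C": 16877, "D": 21253, "E": 13742}

total = sum(counts.values())
shares = {grade: n / total for grade, n in counts.items()}

print(total)  # 73307
for grade, share in shares.items():
    print(f"{grade}: {share:.1%}")
```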
It is important to note that products are listed in the database multiple times, often because they come in different portion sizes. For example, here are all products named “Snickers”.

Code
df.filter(pl.col("product_name") == "Snickers").select("product_name", "quantity")
shape: (8, 2)
product_name quantity
str str
"Snickers" "300 g, 6 barre…
"Snickers" "42g"
"Snickers" "250g"
"Snickers" "350 g (7 * 50 …
"Snickers" "275g"
"Snickers" "50g"
"Snickers" "49g"
"Snickers" "275 g"

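Because the nutrient columns are given per 100 g, such duplicates should carry roughly the same nutritional information, so they get extra weight in the counts and charts below. This post keeps them. If one wanted to collapse them, a minimal sketch could look like this; it assumes (as a heuristic, not something applied in this analysis) that rows whose names match after lowercasing and whitespace normalization are the same product, and the helper name is hypothetical.

```python
def deduplicate_by_name(rows: list[dict]) -> list[dict]:
    """Keep the first row for each normalized product name (hypothetical heuristic)."""
    seen: set[str] = set()
    unique_rows = []
    for row in rows:
        name = row.get("product_name")
        if name is None:
            # Unnamed products cannot be matched, so keep them all.
            unique_rows.append(row)
            continue
        key = " ".join(name.lower().split())
        if key not in seen:
            seen.add(key)
            unique_rows.append(row)
    return unique_rows


rows = [
    {"product_name": "Snickers", "quantity": "42g"},
    {"product_name": "Snickers", "quantity": "250g"},
    {"product_name": "snickers ", "quantity": "50g"},
    {"product_name": "Bitter Lemon", "quantity": "1 l"},
]
print([r["quantity"] for r in deduplicate_by_name(rows)])  # ['42g', '1 l']
```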
Nutriscore 101

The Nutriscore is intended as a tool for quickly comparing products. The grade is one of “A”, “B”, “C”, “D”, or “E”, but it is derived from an underlying numeric score.

(Note that while the official Nutriscore has a green-to-red scale, we are using a blue-to-red scale to ensure better readability for people with colorblindness.)

This is shown in the next graph which presents the histogram of these numeric scores, colored by their Nutriscore grade.

Code
c = "nutriscore_grade"
color = (
    alt.Color(f"{c}:N")
    .scale(
        zero=False,
        domain=["A", "B", "C", "D", "E"],
        # scheme="darkmulti",
        range=["blue", "lightblue", "gold", "orange", "red"],
    )
    .legend(columns=1, symbolLimit=0, labelLimit=0)
)

alt.Chart(
    df.select("nutriscore_grade", "nutriscore_score")  # .sample(5_000)
).mark_bar().encode(
    x=alt.X("nutriscore_score:Q").bin(step=1),
    y=alt.Y("count():Q"),
    color=color,
    row="nutriscore_grade:N",
).properties(
    height=150, width=700
).resolve_scale(
    y="independent"
)

This mostly agrees with what I could find online:

[Image: Nutriscore.jpeg]

Interestingly, a number of products have an “E” grade but a numeric score below 18, lower than the scores at which most other products are graded “E”. Looking at a random sample of those, they seem to be mostly sugary beverages.

Code
print(
    df.filter(pl.col("nutriscore_grade") == "E")
    .filter(pl.col("nutriscore_score") < 18)
    .select("product_name", "categories_en")
    .sort("categories_en")
    .sample(n=20)["product_name"]
    .to_list()
)
[
    'Rhabarber Nektar',
    'Der Eiskaffee (klassisch)',
    'Fritz-kola',
    'Sardinenfiletd',
    'Ice Tea Mango Maracuja',
    'Paradiso Orangennektar',
    'Cocktail de fruits',
    'Frucht Tiger, Apfel Erdbeere',
    'Limonata Rossa',
    'Energy Drink Fruity Power Typ Heidelbeere',
    'Bitter Lemon',
    'Monster Energy Nitro',
    'Saft',
    'Vanillemilch',
    'Die Limo, Dark Berries + Guarana',
    'babylove Stillsaft',
    'Ice Tea',
    'Ramune',
    'Premium Direktsaft Traube',
    'Saft - Apfelsaft'
]
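That beverages dominate this list is consistent with how the grade is derived from the numeric score: in the original (2017) Nutri-Score algorithm, beverages use stricter cut-offs than solid foods. The sketch below hard-codes those cut-offs as I understand them (foods: A ≤ -1, B 0 to 2, C 3 to 10, D 11 to 18, E ≥ 19; beverages: A only for water, B ≤ 1, C 2 to 5, D 6 to 9, E ≥ 10). Treat them as an assumption to check against the official specification; the algorithm has since been revised, and the function name is hypothetical.

```python
def grade_from_score(score: int, is_beverage: bool = False, is_water: bool = False) -> str:
    """Map a numeric Nutriscore to a letter grade (assumed 2017 cut-offs)."""
    if is_beverage:
        if is_water:
            return "A"  # Among beverages, only water gets an "A".
        cutoffs = [(1, "B"), (5, "C"), (9, "D")]
    else:
        cutoffs = [(-1, "A"), (2, "B"), (10, "C"), (18, "D")]
    # Return the first grade whose upper bound is not exceeded.
    for upper_bound, grade in cutoffs:
        if score <= upper_bound:
            return grade
    return "E"


print(grade_from_score(12))                    # D
print(grade_from_score(12, is_beverage=True))  # E
```

The same score of 12 yields a “D” for a food but an “E” for a beverage, which is the pattern in the list above.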

Next, we can have a look at some of the nutrients. Let’s plot fat per 100g against sugar per 100g and group everything by Nutriscore.

Code
x = "sugars_100g"
y = "fat_100g"


chart = (
    alt.Chart(
        df.filter(pl.col("product_name").is_not_null()).select(x, y, c)
        # .sample(5_000)
    )
    .mark_point(clip=True, filled=True, opacity=0.2)
    .encode(
        x=alt.X(f"{x}:Q").scale(domain=(0, 100)),
        y=alt.Y(f"{y}:Q").scale(domain=(0, 100)),
        color=color,
        tooltip=["product_name:N"],
    )
    .properties(width=220, height=300)
    .facet(facet=f"{c}:N", columns=3)
)

chart

Unsurprisingly, “D” and “E” grade products tend to contain more sugar and fat.

Let’s look at a histogram of the contained calories (per 100g) next.

Code
alt.Chart(
    df.select("nutriscore_grade", "nutriscore_score", "energy-kcal_100g").filter(
        pl.col("energy-kcal_100g") < 1000
    )  # .sample(5_000)
).mark_bar().encode(
    x=alt.X("energy-kcal_100g:Q").bin(step=10),
    y=alt.Y("count():Q"),
    color=color,
    row="nutriscore_grade:N",
).properties(
    height=150, width=700
).resolve_scale(
    y="independent"
)

Interestingly, there is a spike at around 300 to 350 kcal for the “A” grade foods. Let’s look at what they are.

Code
df.filter(pl.col("nutriscore_grade") == "A").filter(
    pl.col("energy-kcal_100g").is_between(300, 350)
).select("product_name", "categories_en").sample(20)["product_name"].to_list()
['Conchiglioni Nudeln',
 'Wow kakao',
 'Soczewica',
 'Couscous',
 'Bio-Dinkel-Vollkornpasta Fusilli',
 'Soja-Medaillons',
 'Nudeln Kichererbsen',
 'rote Linsen Strozzapreti',
 'Hörnli',
 None,
 'Barilla Integrale Fussili',
 'Rapunzel Sportler Brei',
 'Funghi Misti',
 'Superior Pizza-Mix Mischung zur Zubereitung von Pizza',
 'Dinkel Couscous',
 'Vollkorn Fusilli',
 'Nudeln Eiernudeln Spiralen',
 'Reis',
 'Basmatireis weiß',
 'Spaghetti Vollkorn']

This list contains a lot of pasta, cereal, and rice, but also flour.

Let’s have a look at the ten healthiest foods according to their numeric Nutriscore.

Code
df.sort("nutriscore_score").head(10)["product_name"].to_list()
['Veggy Love Mexican',
 'Schwarze Bohnen-Tempeh Natur',
 'Wachtelbohnen-Tempeh Natur',
 'junge Erbsen',
 'Morchel-Hüte',
 'Veggie Love Orient',
 'Kapucijners',
 'Kichererbsen Bohnen-Tempeh Natur',
 'Gemüse Erbsen',
 'Knoblauch granuliert']

And of course the ten least healthy foods.

Code
df.sort("nutriscore_score").tail(10)["product_name"].to_list()
['Weiße Schokolade',
 'Chai latte Classic India - Typ Vanille-Zimt',
 'Schweizer Bio-Alpenvollmilchschokolade',
 'chai latte',
 'Chocolat Bourbon Vanille',
 'chai latte',
 'Zarte weiße Schokolade',
 'Bio-Trinkschokolade',
 'Flap Jack Apricot Smoothie',
 'Chai Latte, Classic India Weniger Süss Typ Vanille...']

This also makes sense, I guess. So far, so good.

Greenwashing?

I am frequently surprised that a specific product got an “A” grade: how does some chocolate get an “A” even though it is high in sugar? Is there any way that companies are gaming the system?

Let’s have a look at all the products that contain the word “choco” and see if we find something.
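One caveat with this filter: most products in this data set carry German names, and the German spelling “Schokolade” does not contain the substring “choco” (for example 'Weiße Schokolade' from the list of least healthy foods). A sketch of a slightly broader, case-insensitive match follows; the pattern and helper name are my own choices and are not what the charts below use.

```python
import re
from typing import Optional

# Case-insensitive match for English/French "choco..." and German "schoko...".
CHOCOLATE_PATTERN = re.compile(r"choco|schoko", flags=re.IGNORECASE)


def mentions_chocolate(product_name: Optional[str]) -> bool:
    """Return True if the product name looks chocolate-related."""
    return product_name is not None and bool(CHOCOLATE_PATTERN.search(product_name))


names = ["Weiße Schokolade", "Chocolat Bourbon Vanille", "Bitter Lemon", None]
print([mentions_chocolate(n) for n in names])  # [True, True, False, False]
```

In polars, the same idea could be expressed with the inline flag, e.g. pl.col("product_name").str.contains("(?i)choco|schoko").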

Code
x = "sugars_100g"
y = "fat_100g"
(
    alt.Chart(
        df.filter(pl.col("product_name").str.to_lowercase().str.contains("choco"))
        .sort("nutriscore_grade")
        .select(x, y, "nutriscore_grade")
    )
    .mark_point(filled=True, clip=True)
    .encode(
        x=alt.X(f"{x}:Q").scale(domain=(0, 100)),
        y=alt.Y(f"{y}:Q").scale(domain=(0, 100)),
        color=color,
    )
    .properties(width=300, height=300)
)

Hmm, so this seems to suggest that “A” grades are actually assigned properly, namely when the sugar and fat content is low. Let’s check yoghurts as well.

Code
x = "sugars_100g"
y = "fat_100g"
(
    alt.Chart(
        df.filter(pl.col("product_name").str.to_lowercase().str.contains("yoghurt"))
        .sort("nutriscore_grade")
        .select(x, y, "nutriscore_grade")
    )
    .mark_point(filled=True, clip=True)
    .encode(
        x=alt.X(f"{x}:Q").scale(domain=(0, 100)),
        y=alt.Y(f"{y}:Q").scale(domain=(0, 100)),
        color=color,
    )
    .properties(width=300, height=300)
)

How about sweet snacks (a category in the categories_en column) in general?

Code
x = "sugars_100g"
y = "fat_100g"
(
    alt.Chart(
        df.filter(
            pl.col("categories_en").str.to_lowercase().str.contains("sweet snacks")
        )
        .sort("nutriscore_grade")
        .select(x, y, "nutriscore_grade")
    )
    .mark_point(filled=True, clip=True, opacity=0.2)
    .encode(
        x=alt.X(f"{x}:Q").scale(domain=(0, 100)),
        y=alt.Y(f"{y}:Q").scale(domain=(0, 100)),
        color=color,
    )
    .properties(width=300, height=300)
)

What about products labeled as vegan?

Code
x = "sugars_100g"
y = "fat_100g"
(
    alt.Chart(
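        df.filter(
            # Match the whole tag "vegan" (optionally prefixed with "en:"); a plain
            # substring search would also match "en:non-vegan" and "en:maybe-vegan".
            pl.col("ingredients_analysis_tags")
            .str.to_lowercase()
            .str.contains(r"(^|,)(en:)?vegan(,|$)")
        )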
        .sort("nutriscore_grade")
        .select(x, y, "nutriscore_grade")
    )
    .mark_point(filled=True, clip=True)
    .encode(
        x=alt.X(f"{x}:Q").scale(domain=(0, 100)),
        y=alt.Y(f"{y}:Q").scale(domain=(0, 100)),
        color=color,
    )
    .properties(width=200, height=300)
    .facet(facet="nutriscore_grade:N", columns=3)
)

Next is protein vs. sugar, again for the “choco” products.

Code
x = "sugars_100g"
y = "proteins_100g"
(
    alt.Chart(
        df.filter(pl.col("product_name").str.to_lowercase().str.contains("choco"))
        .sort("nutriscore_grade")
        .select(x, y, "nutriscore_grade")
    )
    .mark_point(filled=True, clip=True)
    .encode(
        x=alt.X(f"{x}:Q").scale(domain=(0, 100)),
        y=alt.Y(f"{y}:Q").scale(domain=(0, 100)),
        color=color,
    )
    .properties(width=200, height=300)
    .facet(facet="nutriscore_grade:N", columns=3)
)

So, no greenwashing after all? I need to dig further into this.

Correlations

Code
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer
import collections
import numpy as np

Let’s see which of the nutrients correlate with the Nutriscore.

To do that, we pick all nutrient columns (those ending in _100g) in which at most 10% of the entries are null.

Code
d = (
    (
        df.select([c for c in df.columns if c.endswith("_100g")]).null_count()
        / len(df)
        * 100
    )
    <= 10
)[0].to_dict()
columns = ["nutriscore_score"] + [key for key, values in d.items() if values[0]]

display(
    df.select(columns)
    .to_pandas()
    .corr()[["nutriscore_score"]]
    .iloc[1:-1, :]
    .sort_values("nutriscore_score")
    .style.background_gradient(cmap="RdBu", vmin=-1, vmax=1)
)
  nutriscore_score
proteins_100g 0.077772
salt_100g 0.094111
sodium_100g 0.094118
carbohydrates_100g 0.146849
energy-kcal_100g 0.222867
energy_100g 0.336517
sugars_100g 0.370191
fat_100g 0.488456
saturated-fat_100g 0.578473

So (saturated) fat is the nutrient most strongly correlated with the numeric Nutriscore, although even that correlation is only moderate (about 0.58).

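The obvious next question is how well we can predict the numeric score from the nutrients. We’ll use a simple decision tree (a DecisionTreeClassifier with max_depth=15, which treats every integer score as a separate class), trained on 20% of the data and evaluated on the remaining 80% (test_size=0.80).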
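One detail of the setup is worth spelling out: scikit-learn’s Normalizer rescales each sample (row) to unit norm, L2 by default, rather than standardizing each feature (column). After this transformation, the tree only sees the nutrient proportions within a product, not the absolute amounts per 100 g. (Decision trees do not need per-feature scaling at all, since their splits are unaffected by monotonic rescaling of a single feature.) The numpy sketch below reproduces the row-wise L2 scaling on two made-up products, one with exactly twice the nutrients of the other; the helper name is hypothetical.

```python
import numpy as np


def l2_normalize_rows(X: np.ndarray) -> np.ndarray:
    """Scale every row to unit Euclidean norm, as Normalizer(norm="l2") does."""
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    # Leave all-zero rows unchanged instead of dividing by zero.
    norms[norms == 0] = 1.0
    return X / norms


# Two made-up products; columns are fat, sugars, proteins (g per 100 g).
X = np.array([[10.0, 20.0, 5.0],
              [20.0, 40.0, 10.0]])
X_normalized = l2_normalize_rows(X)
print(np.allclose(X_normalized[0], X_normalized[1]))  # True
```

Both products end up as the same point, so any model trained on the normalized features cannot tell them apart.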

Code
columns_for_fitting = [
    "energy-kcal_100g",
    "energy_100g",
    "fat_100g",
    "saturated-fat_100g",
    "carbohydrates_100g",
    "sugars_100g",
    "proteins_100g",
    "salt_100g",
    "sodium_100g",
    # "nutrition-score-fr_100g",
]

# columns_for_fitting =columns
Code
y = df.select("nutriscore_score").to_numpy().flatten()
logger.info(y.shape)

X = df.select(
    [
        c
        for c in columns_for_fitting
        if (c != "nutriscore_score" and c != "nutrition-score-fr_100g")
    ]
).to_numpy()
logger.info(X.shape)


logger.info("Keep only data that has no nans.")

select = np.sum(np.isnan(X), axis=1) == 0
y = y[select]
X = X[select, :]
logger.info(y.shape)
logger.info(X.shape)


transformer = Normalizer().fit(X)

X_train, X_test, y_train, y_test = train_test_split(
    transformer.transform(X), y, test_size=0.80, random_state=2023
)


clf = tree.DecisionTreeClassifier(max_depth=15)
clf = clf.fit(X_train, y_train)


df_tree = pl.concat(
    [
        pl.DataFrame(
            {
                "actual score": y_test,
                "predicted score": clf.predict(X_test),
                "label": "test",
            }
        ),
        pl.DataFrame(
            {
                "actual score": y_train,
                "predicted score": clf.predict(X_train),
                "label": "train",
            }
        ),
    ]
).with_columns(err=pl.col("predicted score") - pl.col("actual score"))


chart_v1 = (
    alt.Chart(df_tree)
    .mark_point(filled=True, opacity=0.02)
    .encode(
        x="actual score:Q",
        y="predicted score:Q",
        # y="err:Q",
        color="label:N",
        column="label:N",
    )
    .properties(width=300, height=300)
)

display(chart_v1)

A little underwhelming. Maybe adding information on the categories helps. Let’s find the 20 most frequent category labels (counted over the first 100 products only, as in lol_categories[:100] below) and one-hot-encode them.

Code
lol_categories = (
    df.select("categories_en")
    .with_columns(pl.col("categories_en").str.split(","))["categories_en"]
    .to_list()
)


categories = []
for l in lol_categories[:100]:
    categories.extend(l)


hist = dict(collections.Counter(categories))

df_hist = (
    pl.DataFrame(
        {
            "categorie": [key for key, _ in hist.items()],
            "count": [value for _, value in hist.items()],
        }
    )
    .sort("count", descending=True)
    .head(20)
)
print(df_hist["categorie"].to_list())
[
    'Plant-based foods and beverages',
    'Plant-based foods',
    'Beverages',
    'Cereals and potatoes',
    'Snacks',
    'Cereals and their products',
    'Sweet snacks',
    'Cocoa and its products',
    'Carbonated drinks',
    'Sodas',
    'Chocolates',
    'Colas',
    'Canned foods',
    'Meals',
    'Fruits and vegetables based foods',
    'Artificially sweetened beverages',
    'Sweetened beverages',
    'Seeds',
    'Cereal grains',
    'Canned plant-based foods'
]

Alright, so we can now one-hot-encode these. Let’s run the classifier again, this time with these labels added, and also check the correlations again.
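A caveat with the encoding used here: str.contains does a substring (regex) match, so the label “Snacks” would also fire for “Sweet snacks”, and “Plant-based foods” fires for every product tagged “Plant-based foods and beverages”. That is most likely why those two plant-based columns end up with identical correlations in the table that follows. A minimal sketch of an exact-match alternative, which splits the comma-separated categories_en string into whole labels first (the helper name is hypothetical):

```python
from typing import Optional


def one_hot_categories(categories_en: Optional[str], labels: list[str]) -> dict[str, int]:
    """Return 1/0 flags per label, matching whole comma-separated entries only."""
    present = {
        c.strip().lower() for c in (categories_en or "").split(",") if c.strip()
    }
    return {f"1hot__{label}": int(label.lower() in present) for label in labels}


labels = ["Plant-based foods", "Plant-based foods and beverages", "Snacks"]
row = "Plant-based foods and beverages,Beverages,Plant-based beverages"
print(one_hot_categories(row, labels))
# {'1hot__Plant-based foods': 0, '1hot__Plant-based foods and beverages': 1, '1hot__Snacks': 0}
```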

Code
df_one_hot = df.with_columns(
    (pl.col("categories_en").str.to_lowercase().str.contains(x.lower()))
    .cast(int)
    .alias(f"1hot__{x}")
    for x in df_hist["categorie"].to_list()
)
columns_one_hot = [c for c in df_one_hot.columns if "1hot" in c]
logger.info(columns_one_hot)

y = df_one_hot.select("nutriscore_score").to_numpy().flatten()
logger.info(y.shape)

X = df_one_hot.select(
    [
        c
        for c in columns_for_fitting + columns_one_hot
        if (c != "nutriscore_score" and c != "nutrition-score-fr_100g")
    ]
).to_numpy()
logger.info(X.shape)


logger.info("Keep only data that has no nans.")

select = np.sum(np.isnan(X), axis=1) == 0
y = y[select]
X = X[select, :]
logger.info(y.shape)
logger.info(X.shape)

d = (
    (
        df_one_hot.select(
            [c for c in df_one_hot.columns if c.endswith("_100g")]
        ).null_count()
        / len(df)
        * 100
    )
    <= 10
)[0].to_dict()

columns = (
    ["nutriscore_score"]
    + [key for key, values in d.items() if values[0]]
    + columns_one_hot
)

display(
    df_one_hot.select(columns)
    .to_pandas()
    .corr()[["nutriscore_score"]]
    .iloc[1:-1, :]
    .sort_values("nutriscore_score")
    .style.background_gradient(cmap="RdBu", vmin=-1, vmax=1)
)


logger.info(X.shape)
transformer = Normalizer().fit(X)

X_train, X_test, y_train, y_test = train_test_split(
    transformer.transform(X), y, test_size=0.80, random_state=2023
)

clf = tree.DecisionTreeClassifier(max_depth=15)
clf = clf.fit(X_train, y_train)


df_tree_one_hot = pl.concat(
    [
        pl.DataFrame(
            {
                "actual score": y_test,
                "predicted score": clf.predict(X_test),
                "label": "test",
            }
        ),
        pl.DataFrame(
            {
                "actual score": y_train,
                "predicted score": clf.predict(X_train),
                "label": "train",
            }
        ),
    ]
).with_columns(err=pl.col("predicted score") - pl.col("actual score"))


new_chart = (
    alt.Chart(df_tree_one_hot)
    .mark_point(filled=True, opacity=0.02)
    .encode(
        x="actual score:Q",
        y="predicted score:Q",
        # y="err:Q",
        color="label:N",
        column="label:N",
    )
    .properties(width=300, height=300)
)


chart_v1 & new_chart
  nutriscore_score
1hot__Beverages -0.393359
1hot__Plant-based foods -0.386696
1hot__Plant-based foods and beverages -0.386696
1hot__Cereals and potatoes -0.274596
1hot__Cereals and their products -0.267379
1hot__Fruits and vegetables based foods -0.224242
1hot__Seeds -0.180736
1hot__Canned foods -0.158242
1hot__Cereal grains -0.135658
1hot__Meals -0.105492
1hot__Carbonated drinks -0.046306
1hot__Artificially sweetened beverages -0.041858
1hot__Sweetened beverages -0.017318
1hot__Colas -0.008903
1hot__Sodas -0.005889
proteins_100g 0.077772
salt_100g 0.094111
sodium_100g 0.094118
carbohydrates_100g 0.146849
energy-kcal_100g 0.222867
1hot__Chocolates 0.308169
energy_100g 0.336517
sugars_100g 0.370191
1hot__Cocoa and its products 0.387590
1hot__Snacks 0.444509
1hot__Sweet snacks 0.450395
fat_100g 0.488456
saturated-fat_100g 0.578473
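nutrition-score-fr_100g 1.000000

Note that nutrition-score-fr_100g shows up here with a correlation of 1.0, i.e. it essentially duplicates the target. It is missing from the first correlation table only because .iloc[1:-1, :] drops the last row; with the one-hot columns appended, the last row is now 1hot__Canned plant-based foods, which is why that category does not appear in this table. The model is unaffected, since nutrition-score-fr_100g is explicitly excluded from the features.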
Code
df_tree_full = pl.concat(
    [
        df_tree.with_columns(pl.lit("simple").alias("model")),
        df_tree_one_hot.with_columns(pl.lit("one_hot").alias("model")),
    ]
)
Code
alt.Chart(df_tree_full).mark_bar().encode(
    x=alt.X("err:Q").bin(step=2).axis(values=np.arange(-30, 35, 5)),
    y=alt.Y("count():Q"),
    color="label:N",
    row="model:N",
    column="label:N",
).properties(width=300, height=200).resolve_scale(y="independent")
Code
df_tree_full.groupby("label", "model").agg(
    err_min=pl.col("err").min(),
    err_mean=pl.col("err").mean(),
    err_median=pl.col("err").median(),
    err_std=pl.col("err").std(),
    err_max=pl.col("err").max(),
)
shape: (4, 7)
label model err_min err_mean err_median err_std err_max
str str i64 f64 f64 f64 i64
"test" "one_hot" -35 0.062399 0.0 5.08557 31
"test" "simple" -40 -0.654212 0.0 5.912326 43
"train" "simple" -30 -0.688398 0.0 4.409293 25
"train" "one_hot" -33 0.011566 0.0 3.923673 35

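This looks fine for now: with the one-hot-encoded categories, the mean error on the test set moves from -0.65 to 0.06 and its standard deviation shrinks from 5.9 to 5.1 points, so the error is mostly acceptable.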

Let’s try one more time, this time predicting the letter grade (“A” to “E”) directly instead of the numeric score.

Code
df_one_hot = df.with_columns(
    (pl.col("categories_en").str.to_lowercase().str.contains(x.lower()))
    .cast(int)
    .alias(f"1hot__{x}")
    for x in df_hist["categorie"].to_list()
)
columns_one_hot = [c for c in df_one_hot.columns if "1hot" in c]

y = df_one_hot.select("nutriscore_grade").to_numpy().flatten()

X = df_one_hot.select(
    [
        c
        for c in columns_for_fitting + columns_one_hot
        if (c != "nutriscore_score" and c != "nutrition-score-fr_100g")
    ]
).to_numpy()


select = np.sum(np.isnan(X), axis=1) == 0
y = y[select]
X = X[select, :]
logger.info(y.shape)
logger.info(X.shape)

from sklearn.metrics import confusion_matrix

grades = ["A", "B", "C", "D", "E"]

clf = tree.DecisionTreeClassifier(max_depth=10)
transformer = Normalizer().fit(X)
X_train, X_test, y_train, y_test = train_test_split(
    transformer.transform(X), y, test_size=0.80, random_state=2023
)
clf.fit(X_train, y_train)

y_test_predict = clf.predict(X_test)
y_train_predict = clf.predict(X_train)


def show_confusion_matrix(y_true, y_pred):
    """Display the confusion matrix whose (i, j) entry counts the samples with
    true grade grades[i] that were predicted as grades[j]."""
    cm = confusion_matrix(y_true, y_pred, labels=grades)
    predicted_columns = [f"predicts {g}" for g in grades]
    display(
        # orient="row": each row of cm belongs to one true grade.
        pl.DataFrame(cm, schema=predicted_columns, orient="row")
        .with_columns(pl.Series("true label", grades))
        .select("true label", *predicted_columns)
        .to_pandas()
        .style.background_gradient()
    )


logger.info("Training confusion matrix")
show_confusion_matrix(y_train, y_train_predict)

logger.info("Testing confusion matrix")
show_confusion_matrix(y_test, y_test_predict)
Training confusion matrix (20% of the rows):
  true label predicts A predicts B predicts C predicts D predicts E
0 A 1677 123 215 34 3
1 B 274 976 432 143 13
2 C 89 179 2262 689 28
3 D 14 41 374 3388 123
4 E 4 12 153 345 1983

Test confusion matrix (80% of the rows):
  true label predicts A predicts B predicts C predicts D predicts E
0 A 5814 850 1079 253 25
1 B 1440 3327 2078 665 67
2 C 490 1103 7390 3422 236
3 D 137 307 2182 12228 1034
4 E 46 107 677 1874 7465
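To summarize the test confusion matrix in two numbers, the sketch below computes the overall accuracy (diagonal over total) and the share of predictions that are off by at most one grade. The counts are copied from the second (test) confusion matrix above.

```python
import numpy as np

# Test confusion matrix: rows are true grades A..E, columns are predicted grades.
cm_test = np.array([
    [5814, 850, 1079, 253, 25],
    [1440, 3327, 2078, 665, 67],
    [490, 1103, 7390, 3422, 236],
    [137, 307, 2182, 12228, 1034],
    [46, 107, 677, 1874, 7465],
])

total = cm_test.sum()
accuracy = np.trace(cm_test) / total

# Entries with |true - predicted| <= 1 lie on the three central diagonals.
i, j = np.indices(cm_test.shape)
within_one_grade = cm_test[np.abs(i - j) <= 1].sum() / total

print(f"{accuracy:.1%}")          # 66.7%
print(f"{within_one_grade:.1%}")  # 92.5%
```

So roughly two thirds of the test products get exactly the right grade, and most of the remaining errors are confusions between neighboring grades.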

Summary

This was a good first deep dive into the data. Nothing suspicious so far: no obvious greenwashing, and the prediction works reasonably well (as expected). To be continued!

 